{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import sys\n",
    "import os\n",
    "import math\n",
    "# Make modules under ../../data importable (text_data is imported from there further down -- presumably; confirm its location).\n",
    "sys.path.append(os.path.abspath(\"../../data\"))\n",
    "\n",
    "# Seed PyTorch's global RNG so the torch.rand / nn.Embedding weight initialisation is repeatable across runs.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Passed to wrapper.generate_datasets(); the training loop indexes element 0 of every batch.\n",
    "BATCH_SIZE = 1\n",
    "# Passed to the dataset wrapper as its first constructor argument.\n",
    "SP_VOCAB_SIZE = 1000\n",
    "# First entry of Wrapper.split_lengths.\n",
    "TRAIN_SIZE = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
      "100%|██████████| 3/3 [00:00<00:00, 140.63it/s]\n",
      "sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=1000 --model_type=unigram\n",
      "sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : \n",
      "trainer_spec {\n",
      "  input: tokens.txt\n",
      "  input_format: \n",
      "  model_prefix: cnn_dailymail\n",
      "  model_type: UNIGRAM\n",
      "  vocab_size: 1000\n",
      "  self_test_sample_size: 0\n",
      "  character_coverage: 0.9995\n",
      "  input_sentence_size: 0\n",
      "  shuffle_input_sentence: 1\n",
      "  seed_sentencepiece_size: 1000000\n",
      "  shrinking_factor: 0.75\n",
      "  max_sentence_length: 4192\n",
      "  num_threads: 16\n",
      "  num_sub_iterations: 2\n",
      "  max_sentencepiece_length: 16\n",
      "  split_by_unicode_script: 1\n",
      "  split_by_number: 1\n",
      "  split_by_whitespace: 1\n",
      "  split_digits: 0\n",
      "  treat_whitespace_as_suffix: 0\n",
      "  required_chars: \n",
      "  byte_fallback: 0\n",
      "  vocabulary_output_piece_score: 1\n",
      "  train_extremely_large_corpus: 0\n",
      "  hard_vocab_limit: 1\n",
      "  use_all_vocab: 0\n",
      "  unk_id: 0\n",
      "  bos_id: 1\n",
      "  eos_id: 2\n",
      "  pad_id: -1\n",
      "  unk_piece: <unk>\n",
      "  bos_piece: <s>\n",
      "  eos_piece: </s>\n",
      "  pad_piece: <pad>\n",
      "  unk_surface:  ⁇ \n",
      "}\n",
      "normalizer_spec {\n",
      "  name: nmt_nfkc\n",
      "  add_dummy_prefix: 1\n",
      "  remove_extra_whitespaces: 1\n",
      "  escape_whitespaces: 1\n",
      "  normalization_rule_tsv: \n",
      "}\n",
      "denormalizer_spec {}\n",
      "trainer_interface.cc(319) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.\n",
      "trainer_interface.cc(174) LOG(INFO) Loading corpus: tokens.txt\n",
      "trainer_interface.cc(375) LOG(INFO) Loaded all 2161 sentences\n",
      "trainer_interface.cc(390) LOG(INFO) Adding meta_piece: <unk>\n",
      "trainer_interface.cc(390) LOG(INFO) Adding meta_piece: <s>\n",
      "trainer_interface.cc(390) LOG(INFO) Adding meta_piece: </s>\n",
      "trainer_interface.cc(395) LOG(INFO) Normalizing sentences...\n",
      "trainer_interface.cc(456) LOG(INFO) all chars count=101295\n",
      "trainer_interface.cc(467) LOG(INFO) Done: 99.9526% characters are covered.\n",
      "trainer_interface.cc(477) LOG(INFO) Alphabet size=68\n",
      "trainer_interface.cc(478) LOG(INFO) Final character coverage=0.999526\n",
      "trainer_interface.cc(510) LOG(INFO) Done! preprocessed 2161 sentences.\n",
      "unigram_model_trainer.cc(138) LOG(INFO) Making suffix array...\n",
      "unigram_model_trainer.cc(142) LOG(INFO) Extracting frequent sub strings...\n",
      "unigram_model_trainer.cc(193) LOG(INFO) Initialized 11093 seed sentencepieces\n",
      "trainer_interface.cc(516) LOG(INFO) Tokenizing input sentences with whitespace: 2161\n",
      "trainer_interface.cc(526) LOG(INFO) Done! 6025\n",
      "unigram_model_trainer.cc(488) LOG(INFO) Using 6025 sentences for EM training\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=4758 obj=13.0431 num_tokens=13019 num_tokens/piece=2.73623\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=4170 obj=11.4914 num_tokens=13140 num_tokens/piece=3.15108\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=3123 obj=11.5758 num_tokens=14016 num_tokens/piece=4.48799\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=3120 obj=11.4633 num_tokens=14040 num_tokens/piece=4.5\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=2340 obj=11.859 num_tokens=15424 num_tokens/piece=6.59145\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=2340 obj=11.721 num_tokens=15431 num_tokens/piece=6.59444\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1755 obj=12.3254 num_tokens=17151 num_tokens/piece=9.77265\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1755 obj=12.1667 num_tokens=17155 num_tokens/piece=9.77493\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1316 obj=12.8195 num_tokens=19020 num_tokens/piece=14.4529\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1316 obj=12.6626 num_tokens=19020 num_tokens/piece=14.4529\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1100 obj=13.101 num_tokens=20116 num_tokens/piece=18.2873\n",
      "unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1100 obj=13.0108 num_tokens=20117 num_tokens/piece=18.2882\n",
      "trainer_interface.cc(604) LOG(INFO) Saving model: cnn_dailymail.model\n",
      "trainer_interface.cc(615) LOG(INFO) Saving vocabs: cnn_dailymail.vocab\n"
     ]
    }
   ],
   "source": [
    "from text_data import CNNDatasetWrapper\n",
    "\n",
    "class Wrapper(CNNDatasetWrapper):\n",
    "    \"\"\"CNNDatasetWrapper subclass that overrides a few class-level settings for this notebook.\"\"\"\n",
    "\n",
    "    # Three sizes; the second is 10% of TRAIN_SIZE, rounded down.\n",
    "    # NOTE(review): presumably train / validation / test split sizes -- confirm against CNNDatasetWrapper.\n",
    "    split_lengths = [TRAIN_SIZE, math.floor(0.1 * TRAIN_SIZE), 100]\n",
    "    x_length = 15\n",
    "    # NOTE(review): the rest of this notebook reads wrapper.y_length, never target_length --\n",
    "    # confirm that the base class actually consumes this attribute.\n",
    "    target_length = 15\n",
    "\n",
    "\n",
    "# Build the wrapper, then pull out the two splits the training loop iterates over.\n",
    "wrapper = Wrapper(SP_VOCAB_SIZE, DEVICE)\n",
    "datasets = wrapper.generate_datasets(BATCH_SIZE)\n",
    "train, valid = datasets[\"train\"], datasets[\"validation\"]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"Single-step tanh RNN cell with an extra linear read-out.\n",
    "\n",
    "    Holds an input projection (input_units x hidden_units), a recurrent\n",
    "    projection with bias (hidden_units x hidden_units) and an output\n",
    "    projection with bias (hidden_units x output_units).  Every parameter is\n",
    "    drawn uniformly from [-k, k) with k = sqrt(1 / hidden_units).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_units, hidden_units, output_units):\n",
    "        super().__init__()\n",
    "        self.input_units = input_units\n",
    "        self.hidden_units = hidden_units\n",
    "\n",
    "        bound = math.sqrt(1/hidden_units)\n",
    "\n",
    "        def uniform(*shape):\n",
    "            # Rescale U[0, 1) samples to U[-bound, bound)\n",
    "            return nn.Parameter(torch.rand(*shape) * 2 * bound - bound)\n",
    "\n",
    "        # Creation order is fixed so that a given RNG seed always yields the same weights\n",
    "        self.input_weight = uniform(input_units, hidden_units)\n",
    "\n",
    "        self.hidden_weight = uniform(hidden_units, hidden_units)\n",
    "        self.hidden_bias = uniform(1, hidden_units)\n",
    "\n",
    "        self.output_weight = uniform(hidden_units, output_units)\n",
    "        self.output_bias = uniform(1, output_units)\n",
    "\n",
    "    def forward(self, x, prev_hidden):\n",
    "        \"\"\"Advance the recurrence by one step.\n",
    "\n",
    "        Returns the tuple (new_hidden, output), where new_hidden is\n",
    "        tanh(x @ input_weight + prev_hidden @ hidden_weight + hidden_bias) and\n",
    "        output is new_hidden @ output_weight + output_bias.\n",
    "        \"\"\"\n",
    "        projected = x @ self.input_weight\n",
    "        recurrent = prev_hidden @ self.hidden_weight\n",
    "        # tanh keeps every component of the state inside (-1, 1)\n",
    "        new_hidden = torch.tanh(projected + recurrent + self.hidden_bias)\n",
    "\n",
    "        return new_hidden, new_hidden @ self.output_weight + self.output_bias"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    \"\"\"One decoding step of a tanh RNN with additive attention over the encoder states.\n",
    "\n",
    "    Each call scores every row of `context` against the previous decoder hidden\n",
    "    state, turns the scores into weights with a softmax across the rows, and\n",
    "    feeds the weighted sum of rows, concatenated with the previous-token vector,\n",
    "    into a tanh RNN cell followed by a linear read-out of width `output_units`.\n",
    "\n",
    "    Every parameter is drawn uniformly from [-k, k) with k = sqrt(1 / hidden_units).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_units, hidden_units, output_units=wrapper.vocab_size):\n",
    "        \"\"\"\n",
    "        input_units:  number of columns in each row of `context`.\n",
    "        hidden_units: width of the decoder hidden state (and of `prev_y`).\n",
    "        output_units: width of the output vector produced at every step.\n",
    "        \"\"\"\n",
    "        super(Decoder, self).__init__()\n",
    "        self.input_units = input_units\n",
    "        self.hidden_units = hidden_units\n",
    "        self.output_units = output_units\n",
    "\n",
    "        k = math.sqrt(1/hidden_units)\n",
    "        # Attention scoring: project the hidden state and the context rows, then reduce to one score per row\n",
    "        self.hidden_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)\n",
    "        self.context_attention_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)\n",
    "        self.attention_weight = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)\n",
    "\n",
    "        # The RNN input is cat([context vector, prev_y]), i.e. input_units + hidden_units columns.\n",
    "        # This equals the previous hidden_units * 2 whenever input_units == hidden_units.\n",
    "        self.context_hidden_weight = nn.Parameter(torch.rand(input_units + hidden_units, hidden_units) * 2 * k - k)\n",
    "        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)\n",
    "        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)\n",
    "\n",
    "        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)\n",
    "        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)\n",
    "\n",
    "    def forward(self, prev_y, prev_hidden, context):\n",
    "        \"\"\"Run one decoder step.\n",
    "\n",
    "        prev_y:      previous-token vector with a single row and hidden_units columns; it is\n",
    "                     concatenated with the (1, input_units) context vector along dim 1.\n",
    "        prev_hidden: previous decoder hidden state (hidden_units values).\n",
    "        context:     matrix with one row per encoder step and input_units columns.\n",
    "\n",
    "        Returns (hidden_x, output_y): the new hidden state and the output vector.\n",
    "        \"\"\"\n",
    "        # Compute attention between the encoder hidden states and the previous decoder hidden state\n",
    "        context_attns = context @ self.context_attention_weight\n",
    "        cross = torch.tanh(context_attns + prev_hidden @ self.hidden_attention_weight)\n",
    "        attention = cross @ self.attention_weight.T\n",
    "\n",
    "        # Normalise across the context rows (dim 0).  `attention` has a single column, so a softmax\n",
    "        # over dim 1 would hand every row a weight of exactly 1.0 and reduce attention to a plain sum.\n",
    "        probs = torch.softmax(attention.reshape(context.shape[0]), dim=0)\n",
    "        # Weighted sum of the context rows; same value as summing diag(probs) @ context over dim 0\n",
    "        positional_context = (probs @ context).reshape(1, self.input_units)\n",
    "\n",
    "        # Compute a regular rnn.  Cat the context vector and the previous y state.\n",
    "        input_x = torch.cat([positional_context, prev_y], dim=1) @ self.context_hidden_weight\n",
    "        hidden_x = torch.tanh(input_x + prev_hidden @ self.hidden_weight + self.hidden_bias)\n",
    "\n",
    "        output_y = hidden_x @ self.output_weight + self.output_bias\n",
    "        return hidden_x, output_y"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "outputs": [],
   "source": [
    "class Network(nn.Module):\n",
    "    \"\"\"Encoder/decoder sequence model over token ids that shares one embedding table.\n",
    "\n",
    "    The encoder is stepped over the first `in_sequence_len` tokens of `x`; its hidden\n",
    "    states become the attention context for the decoder, which is stepped\n",
    "    `out_sequence_len` times.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_sequence_len, out_sequence_len, hidden_units=512, embedding_len=wrapper.vocab_size):\n",
    "        \"\"\"\n",
    "        in_sequence_len:  number of encoder steps (tokens read from `x`).\n",
    "        out_sequence_len: number of decoder steps (rows in the returned tensors).\n",
    "        hidden_units:     width of the embeddings and of both hidden states.\n",
    "        embedding_len:    number of rows in the embedding table and width of each output row.\n",
    "        \"\"\"\n",
    "        super(Network, self).__init__()\n",
    "        self.in_sequence_len = in_sequence_len\n",
    "        self.out_sequence_len = out_sequence_len\n",
    "        self.hidden_units = hidden_units\n",
    "        self.embedding_len = embedding_len\n",
    "\n",
    "        self.embedding = nn.Embedding(embedding_len, hidden_units)\n",
    "        self.encoder = Encoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)\n",
    "\n",
    "        self.decoder = Decoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        \"\"\"Encode `x`, then decode `out_sequence_len` steps.\n",
    "\n",
    "        x: 1-D tensor of input token ids (at least in_sequence_len long).\n",
    "        y: 1-D tensor of previous-target token ids.  Step j is fed y[j] while y is long\n",
    "           enough; after that the argmax of the previous step's output is fed instead.\n",
    "\n",
    "        Returns (hiddens, outputs) with one row per decoder step:\n",
    "        (out_sequence_len, hidden_units) and (out_sequence_len, embedding_len).\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(x)\n",
    "        # Create every state tensor next to the embeddings so the model also runs off the CPU\n",
    "        state_kwargs = {\"device\": embedded.device, \"dtype\": embedded.dtype}\n",
    "\n",
    "        # Encode the input sequence.  Steps are collected in a list and concatenated once,\n",
    "        # instead of re-allocating a growing tensor on every step.  The encoder's output\n",
    "        # vector is not used by this model, so it is discarded.\n",
    "        enc_hiddens = [torch.zeros((1, self.hidden_units), **state_kwargs)]\n",
    "        for j in range(self.in_sequence_len):\n",
    "            hidden, _ = self.encoder(embedded[j,:], enc_hiddens[-1][0])\n",
    "            enc_hiddens.append(hidden)\n",
    "\n",
    "        # Decode to the output sequence\n",
    "        # Pass in context (every encoder state except the initial all-zero one)\n",
    "        context = torch.cat(enc_hiddens, dim=0)[1:]\n",
    "        dec_hiddens = [torch.zeros((1, self.hidden_units), **state_kwargs)]\n",
    "        dec_outputs = [torch.zeros((1, self.embedding_len), **state_kwargs)]\n",
    "        for j in range(self.out_sequence_len):\n",
    "            # Use either the actual previous y (from the input), or the generated y if the input sequence is shorter than the generation steps.\n",
    "            if y.shape[0] > j:\n",
    "                prev_y = y[j]\n",
    "            else:\n",
    "                # The previous output row is 1-D, so the token id is the argmax over dim 0\n",
    "                # (dim=1 does not exist on a 1-D tensor and raises IndexError).\n",
    "                prev_y = dec_outputs[-1][0].argmax(dim=0)\n",
    "\n",
    "            prev_y = self.embedding(prev_y.unsqueeze(0))\n",
    "            hidden, output = self.decoder(prev_y, dec_hiddens[-1][0], context)\n",
    "            dec_hiddens.append(hidden)\n",
    "            dec_outputs.append(output)\n",
    "\n",
    "        # Drop the all-zero initial rows\n",
    "        return torch.cat(dec_hiddens, dim=0)[1:], torch.cat(dec_outputs, dim=0)[1:]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# NOTE(review): the model is pinned to the CPU here, while DEVICE (possibly CUDA) was handed to the\n",
    "# dataset wrapper -- confirm both end up on the same device before running on a GPU machine.\n",
    "device = torch.device(\"cpu\")\n",
    "model = Network(\n",
    "    in_sequence_len=wrapper.x_length,\n",
    "    out_sequence_len=wrapper.y_length,\n",
    "    hidden_units=512,\n",
    ").to(device)\n",
    "# Targets equal to wrapper.pad_token are excluded from the loss\n",
    "loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:24, 18.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 train loss: 5.999296477862766 valid loss: 5.748046959147734\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:24, 18.51it/s]\n",
      "448it [00:24, 18.50it/s]\n",
      "448it [00:24, 18.47it/s]\n",
      "448it [00:24, 18.28it/s]\n",
      "448it [00:24, 18.25it/s]\n",
      "448it [00:24, 18.42it/s]\n",
      "448it [00:24, 18.51it/s]\n",
      "448it [00:24, 18.62it/s]\n",
      "448it [00:24, 18.57it/s]\n",
      "448it [00:24, 18.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 train loss: 3.1196700052491257 valid loss: 5.839749308193431\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:24, 18.42it/s]\n",
      "448it [00:23, 18.88it/s]\n",
      "448it [00:22, 19.82it/s]\n",
      "448it [00:22, 19.56it/s]\n",
      "448it [00:22, 19.90it/s]\n",
      "448it [00:22, 20.09it/s]\n",
      "448it [00:22, 19.84it/s]\n",
      "448it [00:23, 19.03it/s]\n",
      "448it [00:27, 16.31it/s]\n",
      "448it [00:27, 16.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20 train loss: 1.3020785356472646 valid loss: 6.145435006010766\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:27, 16.51it/s]\n",
      "448it [00:27, 16.55it/s]\n",
      "448it [00:27, 16.25it/s]\n",
      "448it [00:26, 16.72it/s]\n",
      "448it [00:27, 16.35it/s]\n",
      "448it [00:27, 16.28it/s]\n",
      "448it [00:27, 16.28it/s]\n",
      "448it [00:27, 16.25it/s]\n",
      "448it [00:27, 16.27it/s]\n",
      "448it [00:27, 16.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30 train loss: 0.5017049135806572 valid loss: 6.383245907577813\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:27, 16.41it/s]\n",
      "448it [00:27, 16.35it/s]\n",
      "448it [00:27, 16.34it/s]\n",
      "448it [00:27, 16.51it/s]\n",
      "448it [00:27, 16.52it/s]\n",
      "448it [00:27, 16.58it/s]\n",
      "448it [00:27, 16.36it/s]\n",
      "448it [00:27, 16.15it/s]\n",
      "448it [00:27, 16.34it/s]\n",
      "448it [00:26, 16.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 40 train loss: 0.24164675797302543 valid loss: 6.541052986593807\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:26, 16.65it/s]\n",
      "448it [00:27, 16.53it/s]\n",
      "448it [00:27, 16.49it/s]\n",
      "448it [00:27, 16.30it/s]\n",
      "448it [00:27, 16.39it/s]\n",
      "448it [00:27, 16.34it/s]\n",
      "448it [00:27, 16.30it/s]\n",
      "448it [00:27, 16.19it/s]\n",
      "448it [00:27, 16.23it/s]\n",
      "448it [00:27, 16.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50 train loss: 0.14702079623072808 valid loss: 6.659904620226691\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:27, 16.32it/s]\n",
      "448it [00:27, 16.31it/s]\n",
      "448it [00:27, 16.09it/s]\n",
      "448it [00:27, 16.49it/s]\n",
      "448it [00:27, 16.45it/s]\n",
      "448it [00:27, 16.35it/s]\n",
      "448it [00:27, 16.01it/s]\n",
      "448it [00:28, 15.72it/s]\n",
      "448it [00:28, 15.88it/s]\n",
      "448it [00:28, 15.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 60 train loss: 0.10320704718469642 valid loss: 6.745182869481106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:25, 17.71it/s]\n",
      "448it [00:22, 19.52it/s]\n",
      "448it [00:23, 19.48it/s]\n",
      "448it [00:23, 19.46it/s]\n",
      "448it [00:23, 19.32it/s]\n",
      "448it [00:23, 19.27it/s]\n",
      "448it [00:23, 19.39it/s]\n",
      "448it [00:23, 19.28it/s]\n",
      "448it [00:23, 19.35it/s]\n",
      "448it [00:23, 19.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 70 train loss: 0.07870655508512366 valid loss: 6.816830148883894\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "448it [00:23, 19.32it/s]\n",
      "448it [00:23, 19.20it/s]\n",
      "448it [00:23, 19.17it/s]\n",
      "448it [00:23, 18.93it/s]\n",
      "448it [00:23, 18.89it/s]\n",
      "448it [00:23, 18.97it/s]\n",
      "448it [00:23, 18.81it/s]\n",
      "448it [00:24, 18.62it/s]\n",
      "448it [00:23, 18.75it/s]\n",
      "260it [00:13, 18.95it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[28], line 12\u001B[0m\n\u001B[1;32m     10\u001B[0m target \u001B[38;5;241m=\u001B[39m target\u001B[38;5;241m.\u001B[39mreshape(wrapper\u001B[38;5;241m.\u001B[39my_length)\n\u001B[1;32m     11\u001B[0m loss \u001B[38;5;241m=\u001B[39m loss_fn(pred, target)\n\u001B[0;32m---> 12\u001B[0m \u001B[43mloss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     13\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m     14\u001B[0m train_loss \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m loss\u001B[38;5;241m.\u001B[39mitem()\n",
      "File \u001B[0;32m~/.virtualenvs/nnets/lib/python3.10/site-packages/torch/_tensor.py:488\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    478\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    479\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    480\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    481\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    486\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    487\u001B[0m     )\n\u001B[0;32m--> 488\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    489\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    490\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m~/.virtualenvs/nnets/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    192\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    194\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m    195\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    196\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 197\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    198\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    199\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "EPOCHS = 1000\n",
    "for epoch in range(EPOCHS):\n",
    "    # Run over the training examples\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    # Wrap the iterable itself (not enumerate) so tqdm can read len(train) and show a total\n",
    "    for sequence, target, prev_target in tqdm(train):\n",
    "        optimizer.zero_grad()\n",
    "        # Batches hold a single example, so take row 0 of each tensor\n",
    "        _, pred = model(sequence[0,:], prev_target[0,:])\n",
    "\n",
    "        pred = pred.reshape(wrapper.y_length, wrapper.vocab_size)\n",
    "        target = target.reshape(wrapper.y_length)\n",
    "        loss = loss_fn(pred, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss += loss.item()\n",
    "\n",
    "    if epoch % 10 == 0:\n",
    "        # Compute validation loss.  Unless you have a lot of training data, it won't be able to generalize.\n",
    "        model.eval()\n",
    "        valid_loss = 0\n",
    "        with torch.no_grad():\n",
    "            for sequence, target, prev_target in valid:\n",
    "                # The full previous-target sequence is fed in here, exactly as during training\n",
    "                _, pred = model(sequence[0,:], prev_target[0,:])\n",
    "                pred = pred.reshape(wrapper.y_length, wrapper.vocab_size)\n",
    "                target = target.reshape(wrapper.y_length)\n",
    "                loss = loss_fn(pred, target)\n",
    "                valid_loss += loss.item()\n",
    "\n",
    "        # Average per-batch loss over each split\n",
    "        print(f\"Epoch {epoch} train loss: {train_loss / len(train)} valid loss: {valid_loss / len(valid)}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
